import os
import sys
import gzip
import pysam
from numpy.random import permutation
from Bio import SeqIO


# Name of the dataset to process, taken from the first command-line argument.
# Only "HiSeq" and "CAGE" are recognized further down in this script.
dataset = sys.argv[1]

# Genome assembly whose chromosome sizes are used for the output BAM header.
assembly = "hg38"

def read_chromosome_sizes(assembly, directory="/osc-fs_home/scratch/mdehoon/Data/Genomes"):
    """Read the chromosome names and sizes of a genome assembly.

    The file ``<directory>/<assembly>/<assembly>.chrom.sizes`` is parsed; each
    non-blank line must consist of two whitespace-separated columns, the
    chromosome name and its size. Chromosomes whose name ends in ``_alt`` are
    skipped.

    Args:
        assembly: Name of the genome assembly (for example ``"hg38"``).
        directory: Directory containing one subdirectory per assembly.

    Returns:
        A tuple ``(chromosomes, sizes)`` of two parallel lists: the chromosome
        names, and their sizes as integers, in file order.

    Raises:
        OSError: If the chrom.sizes file cannot be opened.
        ValueError: If a non-blank line does not have exactly two columns or
            the size is not an integer.
    """
    filename = "%s.chrom.sizes" % assembly
    path = os.path.join(directory, assembly, filename)
    chromosomes = []
    sizes = []
    # The context manager closes the file even if a malformed line raises.
    with open(path) as handle:
        for line in handle:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at the end).
                continue
            chromosome, size = line.split()
            if chromosome.endswith("_alt"):
                continue
            chromosomes.append(chromosome)
            sizes.append(int(size))
    return chromosomes, sizes

# Select the input sequence format and the directory holding the sequence
# files for the requested dataset.
if dataset == "HiSeq":
    fmt = "fastq"
    directory = "/osc-fs_home/mdehoon/Data/CASPARs/%s/Fastq/" % dataset
elif dataset == "CAGE":
    fmt = "fasta"
    directory = "/osc-fs_home/mdehoon/Data/CASPARs/%s/Fasta/" % dataset
else:
    # Fail here with a clear message instead of a NameError further down,
    # where `fmt` and `directory` would be used without being defined.
    raise ValueError("Unknown dataset '%s' (expected 'HiSeq' or 'CAGE')" % dataset)

# Count the sequences in all input files of the dataset; `n` accumulates the
# running total over the files, which are visited in sorted filename order.
filenames = sorted(os.listdir(directory))
n = 0
for filename in filenames:
    path = os.path.join(directory, filename)
    if dataset == "HiSeq":
        # HiSeq files are expected to be named <library>.fq.gz
        library, fq, gz = filename.split(".")
        assert fq == "fq"
        assert gz == "gz"
        if library == "t01_r3":  # Sample negative control
            print("Skipping %s" % path)
            continue
    print("Reading %s" % path)
    # The context manager closes the gzip stream even if parsing fails.
    with gzip.open(path, "rt") as stream:
        n += sum(1 for record in SeqIO.parse(stream, fmt))
    # Progress report: cumulative number of sequences seen so far.
    print(n)

print("%s: %d sequences in total" % (dataset, n))

# Draw a random permutation of 0 .. n-1; each read receives one of these
# numbers as its index, so every read gets a distinct value.
indices = permutation(n)

chromosomes, sizes = read_chromosome_sizes(assembly)

filename = "%s.bam" % dataset
print("Writing", filename)

directory = "/osc-fs_home/mdehoon/Data/CASPARs/%s/Mapping/" % dataset

# List the mapping directory before creating the output file, so that a
# missing directory does not leave an empty BAM file behind.
filenames = sorted(os.listdir(directory))

# Number of reads (distinct runs of consecutive alignments sharing a query
# name) that have been assigned an index so far.
i = 0
# NOTE(review): alignments are written as they were read, so this presumably
# relies on the input BAM files using the same reference order as the header
# built here from the chrom.sizes file -- TODO confirm.
with pysam.AlignmentFile(filename, "wb", reference_names=chromosomes, reference_lengths=sizes) as output:
    for filename in filenames:
        if dataset == "HiSeq":
            # HiSeq mapping files are expected to be named <library>.bam
            library, bam = filename.split(".")
            assert bam == "bam"
            if library == "t01_r3":  # Sample negative control
                print("Skipping %s" % library)
                continue
        path = os.path.join(directory, filename)
        print("Reading %s" % path)
        # Reset for every file: the first read of a library must get its own
        # index even if its name equals that of the previous library's last read.
        query_name = None
        # The context manager closes each input file once it has been copied.
        with pysam.AlignmentFile(path) as alignments:
            for alignment in alignments:
                if alignment.query_name != query_name:
                    # First alignment of a new read: take the next index.
                    index = indices[i]
                    i += 1
                    query_name = alignment.query_name
                # All alignments of the same read carry the same index.
                alignment.set_tag("RX", index, "i")
                output.write(alignment)

# Sanity check: the number of reads seen in the mapping files must equal the
# number of sequences counted in the input sequence files.
assert i == n
